#
#  Copyright (C) 2004-2017 Rational Discovery LLC
#
#   @@ All Rights Reserved @@
#  This file is part of the RDKit.
#  The contents are covered by the terms of the BSD license
#  which is included in the file license.txt, found at the root
#  of the RDKit source tree.
#
from __future__ import print_function
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors
import math


def ExplainAtomCode(code, branchSubtract=0, includeChirality=False):
  """ Decodes an atom code into a tuple describing the atom

  **Arguments**:

    - code: the atom code to be considered

    - branchSubtract: (optional) the constant that was subtracted off
      the number of neighbors before integrating it into the code.
      This is used by the topological torsions code.
      NOTE: the value is accepted but not used by this function; the
      branch count is reported exactly as it is stored in the code.

    - includeChirality: (optional) Determines whether or not chirality
      was included when generating the atom code.

  **Returns**: a tuple (element symbol, branch count, pi count). If
    includeChirality is set, a chirality label ('', 'R' or 'S') is
    appended as a 4th element. The symbol is 'X' when the type index
    stored in the code falls outside the table of known atom types.

  >>> m = Chem.MolFromSmiles('C=CC(=O)O')
  >>> code = GetAtomCode(m.GetAtomWithIdx(0))
  >>> ExplainAtomCode(code)
  ('C', 1, 1)
  >>> code = GetAtomCode(m.GetAtomWithIdx(1))
  >>> ExplainAtomCode(code)
  ('C', 2, 1)
  >>> code = GetAtomCode(m.GetAtomWithIdx(2))
  >>> ExplainAtomCode(code)
  ('C', 3, 1)
  >>> code = GetAtomCode(m.GetAtomWithIdx(3))
  >>> ExplainAtomCode(code)
  ('O', 1, 1)
  >>> code = GetAtomCode(m.GetAtomWithIdx(4))
  >>> ExplainAtomCode(code)
  ('O', 1, 0)

  we can do chirality too, that returns an extra element in the tuple:
  >>> m = Chem.MolFromSmiles('C[C@H](F)Cl')
  >>> code = GetAtomCode(m.GetAtomWithIdx(1))
  >>> ExplainAtomCode(code)
  ('C', 3, 0)
  >>> code = GetAtomCode(m.GetAtomWithIdx(1),includeChirality=True)
  >>> ExplainAtomCode(code,includeChirality=True)
  ('C', 3, 0, 'R')

  note that if we don't ask for chirality, we get the right answer even if
  the atom code was calculated with chirality:
  >>> ExplainAtomCode(code)
  ('C', 3, 0)

  non-chiral atoms return '' in the 4th field:
  >>> code = GetAtomCode(m.GetAtomWithIdx(0),includeChirality=True)
  >>> ExplainAtomCode(code,includeChirality=True)
  ('C', 1, 0, '')

  Obviously switching the chirality changes the results:
  >>> m = Chem.MolFromSmiles('C[C@@H](F)Cl')
  >>> code = GetAtomCode(m.GetAtomWithIdx(1),includeChirality=True)
  >>> ExplainAtomCode(code,includeChirality=True)
  ('C', 3, 0, 'S')

  """
  params = rdMolDescriptors.AtomPairsParameters
  nTypeBits = params.numTypeBits
  nBranchBits = params.numBranchBits
  nPiBits = params.numPiBits
  nChiralBits = params.numChiralBits

  def lowBits(value, width):
    # the lowest `width` bits of value, as a plain int
    return int(value & ((1 << width) - 1))

  # the fields are read starting from the least-significant end of the
  # code in this order: branch count, pi count, type index, chirality
  nBranch = lowBits(code, nBranchBits)
  remaining = code >> nBranchBits

  nPi = lowBits(remaining, nPiBits)
  remaining = remaining >> nPiBits

  typeIdx = lowBits(remaining, nTypeBits)
  knownTypes = params.atomTypes
  if typeIdx >= len(knownTypes):
    # type index is not in the table of known atom types
    atomSymbol = 'X'
  else:
    atomSymbol = Chem.GetPeriodicTable().GetElementSymbol(knownTypes[typeIdx])

  explanation = (atomSymbol, nBranch, nPi)
  if includeChirality:
    chiLabels = {0: '', 1: 'R', 2: 'S'}
    chiCode = lowBits(remaining >> nTypeBits, nChiralBits)
    explanation = explanation + (chiLabels[chiCode], )
  return explanation

# Module-level alias for rdMolDescriptors.GetAtomPairAtomCode; the doctests in
# this file use it to generate the atom codes that ExplainAtomCode() decodes.
GetAtomCode = rdMolDescriptors.GetAtomPairAtomCode

def NumPiElectrons(atom):
  """ Returns the number of electrons an atom is using for pi bonding

  **Arguments**:

    - atom: the atom to be considered

  **Returns**: an integer

    - 1 for atoms flagged as aromatic

    - 0 for atoms whose hybridization is SP3

    - otherwise: explicit valence - number of explicit Hs - degree

  **Raises**: ValueError if the explicit valence (after removing the
    explicit Hs) is smaller than the atom's degree

  >>> m = Chem.MolFromSmiles('C=C')
  >>> NumPiElectrons(m.GetAtomWithIdx(0))
  1

  >>> m = Chem.MolFromSmiles('C#CC')
  >>> NumPiElectrons(m.GetAtomWithIdx(0))
  2
  >>> NumPiElectrons(m.GetAtomWithIdx(1))
  2

  >>> m = Chem.MolFromSmiles('O=C=CC')
  >>> NumPiElectrons(m.GetAtomWithIdx(0))
  1
  >>> NumPiElectrons(m.GetAtomWithIdx(1))
  2
  >>> NumPiElectrons(m.GetAtomWithIdx(2))
  1
  >>> NumPiElectrons(m.GetAtomWithIdx(3))
  0

  >>> m = Chem.MolFromSmiles('c1ccccc1')
  >>> NumPiElectrons(m.GetAtomWithIdx(0))
  1

  FIX: this behaves oddly in these cases:
  >>> m = Chem.MolFromSmiles('S(=O)(=O)')
  >>> NumPiElectrons(m.GetAtomWithIdx(0))
  2

  >>> m = Chem.MolFromSmiles('S(=O)(=O)(O)O')
  >>> NumPiElectrons(m.GetAtomWithIdx(0))
  0

  In the second case, the S atom is tagged as sp3 hybridized.

  """
  if atom.GetIsAromatic():
    return 1

  if atom.GetHybridization() != Chem.HybridizationType.SP3:
    # the number of pi electrons is just the number of
    # unsaturations (valence - degree):
    valence = atom.GetExplicitValence() - atom.GetNumExplicitHs()
    degree = atom.GetDegree()
    if valence < degree:
      # the guard fires when the degree is the larger quantity, so the
      # message has to say so (it used to claim the opposite)
      raise ValueError("atom degree exceeds explicit valence")
    return valence - degree

  return 0


def BitsInCommon(v1, v2):
  """ Counts the bit ids shared by two vectors

  **Arguments**:

    - two vectors (sequences of bit ids)

  **Returns**: an integer

  **Notes**

    - the vectors must be sorted

    - duplicate bit IDs are counted more than once; each entry of v2 is
      matched against at most one entry of v1

  >>> BitsInCommon( (1,2,3,4,10), (2,4,6) )
  2

  Here's how duplicates are handled:
  >>> BitsInCommon( (1,2,2,3,4), (2,2,4,5,6) )
  3

  """
  limit = len(v2)
  cursor = 0
  nShared = 0
  for bit in v1:
    # move past the entries of v2 that are smaller than the current bit
    while cursor < limit and v2[cursor] < bit:
      cursor += 1
    if cursor >= limit:
      # v2 has been used up, nothing further can match
      break
    if v2[cursor] == bit:
      # consume the matching entry so that it is only counted once
      cursor += 1
      nShared += 1
  return nShared


def DiceSimilarity(v1, v2, bounds=None):
  """ Implements the DICE similarity metric.
   This is the recommended metric in both the Topological torsions
   and Atom pairs papers.

  **Arguments**:

    - two vectors (sequences of bit ids)

    - bounds: (optional) a threshold used to skip the comparison: if
      the length of the shorter vector divided by the summed lengths
      of both vectors is below this value, 0.0 is returned without
      counting the bits in common.

  **Returns**: a float.

  **Notes**

    - the vectors must be sorted

    - two empty vectors have a similarity of 0.0


  >>> DiceSimilarity( (1,2,3), (1,2,3) )
  1.0
  >>> DiceSimilarity( (1,2,3), (5,6) )
  0.0
  >>> DiceSimilarity( (1,2,3,4), (1,3,5,7) )
  0.5
  >>> DiceSimilarity( (1,2,3,4,5,6), (1,3) )
  0.5

  Note that duplicate bit IDs count multiple times:
  >>> DiceSimilarity( (1,1,3,4,5,6), (1,1) )
  0.5

  but only if they are duplicated in both vectors:
  >>> DiceSimilarity( (1,1,3,4,5,6), (1,) )==2./7
  True

  edge case
  >>> DiceSimilarity( (), () )
  0.0

  and bounds check
  >>> DiceSimilarity( (1,1,3,4), (1,1))
  0.666...
  >>> DiceSimilarity( (1,1,3,4), (1,1), bounds=0.3)
  0.666...
  >>> DiceSimilarity( (1,1,3,4), (1,1), bounds=0.33)
  0.666...
  >>> DiceSimilarity( (1,1,3,4,5,6), (1,1), bounds=0.34)
  0.0

  """
  n1 = len(v1)
  n2 = len(v2)
  total = 1.0 * (n1 + n2)
  if not total:
    # both vectors are empty
    return 0.0
  if bounds and (min(n1, n2) / total) < bounds:
    # screened out by the bounds check, no need to compare the contents
    return 0.0
  return 2.0 * BitsInCommon(v1, v2) / total


def Dot(v1, v2):
  """ Returns the Dot product between two vectors:

  **Arguments**:

    - two vectors (sequences of bit ids)

  **Returns**: an integer

  **Notes**

    - the vectors must be sorted

    - duplicate bit IDs are counted more than once: a bit id present
      in both vectors contributes the square of the smaller of its two
      occurrence counts

  >>> Dot( (1,2,3,4,10), (2,4,6) )
  2

  Here's how duplicates are handled:
  >>> Dot( (1,2,2,3,4), (2,2,4,5,6) )
  5
  >>> Dot( (1,2,2,3,4), (2,4,5,6) )
  2
  >>> Dot( (1,2,2,3,4), (5,6) )
  0
  >>> Dot( (), (5,6) )
  0

  """
  len1, len2 = len(v1), len(v2)
  pos1 = pos2 = 0
  total = 0
  while pos1 < len1:
    # measure the run of identical values that starts at pos1
    current = v1[pos1]
    runStart = pos1
    pos1 += 1
    while pos1 < len1 and v1[pos1] == current:
      pos1 += 1
    run1 = pos1 - runStart

    # move through v2 to the first entry that is not smaller
    while pos2 < len2 and v2[pos2] < current:
      pos2 += 1
    if pos2 >= len2:
      # v2 has been used up, no further contributions are possible
      break

    if v2[pos2] == current:
      # measure the matching run in v2
      runStart = pos2
      pos2 += 1
      while pos2 < len2 and v2[pos2] == current:
        pos2 += 1
      shared = min(run1, pos2 - runStart)
      total += shared * shared
  return total


def CosineSimilarity(v1, v2):
  """ Implements the Cosine similarity metric.
   This is the recommended metric in the LaSSI paper

  **Arguments**:

    - two vectors (sequences of bit ids)

  **Returns**: a float: Dot(v1,v2) divided by the square root of
    Dot(v1,v1)*Dot(v2,v2), or 0.0 if that square root is zero.

  **Notes**

    - the vectors must be sorted

  >>> print('%.3f'%CosineSimilarity( (1,2,3,4,10), (2,4,6) ))
  0.516
  >>> print('%.3f'%CosineSimilarity( (1,2,2,3,4), (2,2,4,5,6) ))
  0.714
  >>> print('%.3f'%CosineSimilarity( (1,2,2,3,4), (1,2,2,3,4) ))
  1.000
  >>> print('%.3f'%CosineSimilarity( (1,2,2,3,4), (5,6,7) ))
  0.000
  >>> print('%.3f'%CosineSimilarity( (1,2,2,3,4), () ))
  0.000

  """
  selfDot1 = Dot(v1, v1)
  selfDot2 = Dot(v2, v2)
  norm = math.sqrt(selfDot1 * selfDot2)
  if not norm:
    # at least one of the vectors contributes nothing
    return 0.0
  return Dot(v1, v2) / norm


# ------------------------------------
#
#  doctest boilerplate
#
def _runDoctests(verbose=None):  # pragma: nocover
  """ Runs doctest.testmod() with ELLIPSIS matching enabled and exits the
  interpreter, using the number of failed tests as the exit status

  **Arguments**:

    - verbose: (optional) passed through to doctest.testmod()

  """
  import doctest
  import sys
  results = doctest.testmod(optionflags=doctest.ELLIPSIS, verbose=verbose)
  # the first field of the result is the failure count
  sys.exit(results[0])


# run the doctests (and exit) when this file is executed as a script
if __name__ == '__main__':  # pragma: nocover
  _runDoctests()
